Seeing is believing

Using FlashTorch 🔦 to shine a light on what neural nets "see"


by Misa Ogura

Hello, I'm Misa 👋


  • Originally from Tokyo, now based in London
  • Cancer Cell Biologist, turned Software Engineer
  • Currently at BBC R&D
  • Co-founder of Women Driven Development
  • Women in Data Science London Ambassador

Convolutional Neural Network (CNN)


Kernel & Convolution


Kernel: a small matrix used for blurring, sharpening, embossing, edge detection, etc.

Convolution: replacing each pixel with the weighted sum of its local neighbours, with the weights given by the kernel
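That definition can be made concrete with a naive NumPy implementation. (Strictly speaking this computes cross-correlation, which is what deep learning frameworks do under the name "convolution".)

```python
import numpy as np

def convolve2d(image, kernel):
    """Naive 2D 'convolution': each output pixel is the kernel-weighted
    sum of the corresponding neighbourhood in the image."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# A vertical edge lights up strongly under the Sobel x kernel
img = np.array([[0, 0, 9, 9],
                [0, 0, 9, 9],
                [0, 0, 9, 9],
                [0, 0, 9, 9]], dtype=float)

sobel_x = np.array([[-1, 0, 1],
                    [-2, 0, 2],
                    [-1, 0, 1]], dtype=float)

print(convolve2d(img, sobel_x))  # every output pixel responds to the edge
```

`cv2.filter2D`, used in the next cell, does the same thing (with border padding, so the output keeps the input's size) but in optimised C.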

In [9]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(18, 6))
ax = fig.add_subplot(1, 3, 1, xticks=[], yticks=[])
ax.imshow(image, cmap='gray')
ax.set_title('Original image')

sobel_x = np.array([[ -1, 0, 1], 
                    [ -2, 0, 2], 
                    [ -1, 0, 1]])

sobel_y = np.array([[ -1, -2, -1], 
                    [ 0, 0, 0], 
                    [ 1, 2, 1]])

kernels = {'Sobel x': sobel_x, 'Sobel y': sobel_y}

for i, (title, kernel) in enumerate(kernels.items()):
    filtered_img = cv2.filter2D(image, -1, kernel)
    
    ax = fig.add_subplot(1, 3, i+2, xticks=[], yticks=[])
    ax.imshow(filtered_img, cmap='gray')
    ax.set_title(title)

Typical CNN Architecture


CNN Visualisation Techniques


Saliency map


Activation maximisation


Demo 1

Introducing FlashTorch & how to visualise saliency maps


First things first...

$ pip install flashtorch

Load an image


In [2]:
from flashtorch.utils import load_image

image = load_image('../examples/images/great_grey_owl_01.jpg')

plt.imshow(image)
plt.title('Original image')
plt.axis('off');

Convert the PIL image to a torch tensor


In [3]:
from flashtorch.utils import apply_transforms

input_ = apply_transforms(image)

print(f'Before: {type(image)}')
print(f'After: {type(input_)}, {input_.shape}')
Before: <class 'PIL.Image.Image'>
After: <class 'torch.Tensor'>, torch.Size([1, 3, 224, 224])

Let's visualise the input first...


In [4]:
# plt.imshow(input_)
# plt.title('Input tensor')
# plt.axis('off');
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

Let's visualise the input - take two


In [5]:
from flashtorch.utils import format_for_plotting

plt.imshow(format_for_plotting(input_))
plt.title('Input tensor')
plt.axis('off');
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Let's visualise the input - take THREE


In [6]:
from flashtorch.utils import denormalize

plt.imshow(format_for_plotting(denormalize(input_)))
plt.title('Input tensor')
plt.axis('off');
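`denormalize` undoes the ImageNet mean/std normalisation applied during preprocessing, bringing the pixel values back into [0, 1] so `imshow` stops clipping. A sketch of what it presumably does (the mean/std values are the standard ImageNet statistics, assumed here):

```python
import torch

def denormalize_sketch(tensor,
                       mean=(0.485, 0.456, 0.406),
                       std=(0.229, 0.224, 0.225)):
    """Invert channel-wise normalisation: x * std + mean, per channel.
    A sketch of what flashtorch's denormalize presumably does."""
    denormed = tensor.clone()  # don't mutate the original input tensor
    for channel, m, s in zip(denormed[0], mean, std):
        channel.mul_(s).add_(m)
    return denormed
```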

Load a pre-trained model & create a backprop object


In [7]:
from flashtorch.saliency import Backprop
from torchvision import models

model = models.alexnet(pretrained=True)

backprop = Backprop(model)
Signature:

    backprop.calculate_gradients(input_, target_class=None, take_max=False)

Retrieve the class index for the object in the input


In [8]:
from flashtorch.utils import ImageNetIndex 

imagenet = ImageNetIndex()
target_class = imagenet['great grey owl']

print(target_class)
24

It also does fuzzy matching, to some extent...


In [9]:
# imagenet['dog']
ValueError: Multiple potential matches found: maltese dog, old english sheepdog, shetland sheepdog, greater swiss mountain dog, bernese mountain dog, french bulldog, eskimo dog, african hunting dog, dogsled, hotdog

Finally, time to calculate the gradients!


In [10]:
gradients = backprop.calculate_gradients(input_, target_class)

print(type(gradients), gradients.shape)
<class 'torch.Tensor'> torch.Size([3, 224, 224])
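Under the hood, vanilla gradient saliency is one forward pass and one backward pass from the target class score: the gradient of that score with respect to the input pixels tells you which pixels matter most. A minimal sketch (not FlashTorch's actual implementation):

```python
import torch
import torch.nn as nn

def vanilla_saliency(model, input_, target_class):
    """Gradient of the target class score w.r.t. the input pixels.
    A sketch of what Backprop.calculate_gradients roughly does."""
    model.eval()                                  # disable dropout etc.
    input_ = input_.clone().requires_grad_(True)  # track grads on the input
    output = model(input_)                        # [1, num_classes]
    one_hot = torch.zeros_like(output)
    one_hot[0, target_class] = 1.0                # backprop from this score only
    output.backward(gradient=one_hot)
    return input_.grad[0]                         # [3, H, W]
```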

You can also take the maximum of the gradients across colour channels


In [11]:
max_gradients = backprop.calculate_gradients(input_, target_class, take_max=True)

print(type(max_gradients), max_gradients.shape)
<class 'torch.Tensor'> torch.Size([1, 224, 224])

Let's inspect gradients by plotting them out


In [12]:
from flashtorch.utils import visualize

visualize(input_, gradients, max_gradients)

Pixels where the animal is present have the strongest positive effects.

But it's quite noisy...

Guided backprop to the rescue!


In [13]:
guided_gradients = backprop.calculate_gradients(input_, target_class, guided=True)
max_guided_gradients = backprop.calculate_gradients(input_, target_class, take_max=True, guided=True)
In [14]:
visualize(input_, guided_gradients, max_guided_gradients)

Now that's much less noisy!

Pixels around the head and eyes have the strongest positive effects.
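Guided backpropagation (Springenberg et al., 2015) gets its cleaner maps by zeroing negative gradients as they flow back through each ReLU, so only pixels that positively contribute to the target activation survive. One way to sketch this in plain PyTorch is with backward hooks (an illustration, not FlashTorch's code):

```python
import torch
import torch.nn as nn

def apply_guided_hooks(model):
    """Clamp negative gradients at every ReLU during the backward pass,
    as in guided backpropagation. Note: in-place ReLUs (torchvision's
    AlexNet uses inplace=True) may need to be set to inplace=False first."""
    handles = []
    for module in model.modules():
        if isinstance(module, nn.ReLU):
            handles.append(module.register_full_backward_hook(
                lambda m, grad_in, grad_out: (torch.clamp(grad_in[0], min=0.0),)
            ))
    return handles  # call handle.remove() on each to restore the model
```

Because the ReLU's own backward already zeroes gradients where the forward activation was negative, clamping its `grad_input` enforces both conditions of guided backprop at once.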

What about a jay?


In [16]:
visualize(input_, guided_gradients, max_guided_gradients)

Or an oystercatcher...


In [18]:
visualize(input_, guided_gradients, max_guided_gradients)

Demo 2

Using FlashTorch to gain additional insights on transfer learning


Transfer Learning


  • A model developed for one task is reused as the starting point for another task

  • Pre-trained models are often used in computer vision & natural language processing tasks

  • Saves compute & time

Flower Classifier


From: DenseNet model, pre-trained on ImageNet (1,000 classes)

To: Flower classifier that recognises 102 species of flowers, using a dataset from the VGG group.
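A typical recipe for this kind of transfer learning — and plausibly what `create_model` does, though its internals aren't shown here — is to freeze the pre-trained features and swap in a new classifier head:

```python
import torch.nn as nn

def make_flower_classifier(backbone, num_classes=102):
    """Reuse a pre-trained backbone's features; replace its classifier
    head with a fresh num_classes-way linear layer. A sketch, assuming
    a DenseNet-style model with a single .classifier Linear layer."""
    for param in backbone.parameters():
        param.requires_grad = False                  # freeze pre-trained features
    backbone.classifier = nn.Linear(backbone.classifier.in_features,
                                    num_classes)     # new trainable head
    return backbone

# e.g. (assumption, not the talk's create_model):
# from torchvision import models
# pretrained_model = make_flower_classifier(models.densenet161(pretrained=True))
```

Only the new head is then trained on the flower dataset, which is why this saves so much compute relative to training from scratch.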

In [20]:
image = load_image('../examples/images/foxglove.jpg')
input_ = apply_transforms(image)

class_index = 96  # foxglove

pretrained_model = create_model()

backprop = Backprop(pretrained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)
/Users/misao/Projects/personal/flashtorch/flashtorch/saliency/backprop.py:93: UserWarning: The predicted class does not equal the target class. Calculating the gradient with respect to the predicted class.
In [21]:
trained_model = create_model('../models/flower_classification_transfer_learning.pt')

backprop = Backprop(trained_model)

guided_gradients = backprop.calculate_gradients(input_, class_index, guided=True)
guided_max_gradients = backprop.calculate_gradients(input_, class_index, take_max=True, guided=True)

visualize(input_, guided_gradients, guided_max_gradients)